package database

import (
	"errors"
	"fmt"
	"gitee.com/zhucheer/orange/cfg"
	"gitee.com/zhucheer/orange/internal"
	"gitee.com/zhucheer/orange/logger"
	"github.com/zhuCheer/pool"
	"gorm.io/driver/clickhouse"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlserver"
	"gorm.io/gorm"
	"sync"
	"time"
)

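// The gorm package already supports many kinds of databases; this file exposes native gorm connections so that gorm can be used to connect to any of them.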

// Values accepted for the "dbType" key of a database.gorm.<name> config
// entry. Register treats any value it does not match explicitly as MySQL.
const (
	// MySQL selects the MySQL dialector.
	MySQL = "mysql"

	// PostgreSQL selects the PostgreSQL dialector.
	PostgreSQL = "postgres"

	// SQLite is declared, but its case in Register is commented out, so this
	// value currently falls back to the MySQL dialector.
	SQLite = "sqlite"

	// SQLServer selects the SQL Server dialector.
	SQLServer = "sqlserver"

	// Clickhouse selects the ClickHouse dialector.
	Clickhouse = "clickhouse"
)

// gormConn is the package-wide GormDB; NewGorm creates it on first use and
// GetGorm reads it.
var gormConn *GormDB

// GormDB holds one connection pool of *gorm.DB handles per registered
// "database.gorm.<name>" config entry, keyed by <name>.
type GormDB struct {
	dsn      string               // DSN of the most recently registered entry; set by Register, logged by insertPool
	connPool map[string]pool.Pool // config entry name -> connection pool
	count    int                  // number of entries RegisterAll found under "database.gorm"
	lock     sync.Mutex           // protects connPool
}

// NewGorm returns the package-wide GormDB. The first call allocates it with
// an empty pool map; every later call hands back that same instance.
func NewGorm() DataBase {
	if gormConn == nil {
		gormConn = &GormDB{connPool: make(map[string]pool.Pool, 0)}
	}
	return gormConn
}

// RegisterAll records how many entries exist under the "database.gorm"
// config section in count, then calls Register for each entry name.
func (my *GormDB) RegisterAll() {
	sections := cfg.Config.GetMap("database.gorm")
	my.count = len(sections)

	for entryName := range sections {
		my.Register(entryName)
	}
}

// Register builds a connection pool for the config entry
// "database.gorm.<name>" and stores it under name.
//
// dsn and dbType are read from that entry. initCap, maxCap, idleTimeout (in
// seconds) and debug are read through getDBIntConfig/getBoolConfig. dbType
// picks the gorm dialector: postgres, sqlserver or clickhouse. Any other
// value, including sqlite whose case is disabled, uses MySQL. Invalid sizing
// or a pool-creation failure is logged and the entry is skipped.
func (my *GormDB) Register(name string) {
	dsn := cfg.Config.GetString("database.gorm." + name + ".dsn")
	dbType := cfg.Config.GetString("database.gorm." + name + ".dbType")

	initCap := getDBIntConfig("gorm", name, "initCap")
	maxCap := getDBIntConfig("gorm", name, "maxCap")
	idleTimeout := getDBIntConfig("gorm", name, "idleTimeout")
	isDebug := getBoolConfig("gorm", name, "debug")

	// Enforce what the message promises (strictly positive). Checking only
	// for zero let negative values through.
	if initCap <= 0 || maxCap <= 0 || idleTimeout <= 0 {
		logger.Error("database config is error initCap,maxCap,idleTimeout should be gt 0")
		return
	}
	my.dsn = dsn

	// connGorm is the pool's factory: it opens one new *gorm.DB.
	connGorm := func() (interface{}, error) {
		var dialector gorm.Dialector
		switch dbType {
		case PostgreSQL:
			dialector = postgres.Open(dsn)
		//case SQLite:
		//	dialector = sqlite.Open(dsn)
		case SQLServer:
			dialector = sqlserver.Open(dsn)
		case Clickhouse:
			dialector = clickhouse.Open(dsn)
		default:
			// MySQL, plus any unrecognized dbType (currently including SQLite).
			dialector = mysql.Open(dsn)
		}

		db, err := gorm.Open(dialector, &gorm.Config{})
		if err != nil {
			// Never hand a failed handle to the pool alongside the error.
			return nil, err
		}
		if isDebug {
			db = db.Debug()
		}
		return db, nil
	}

	// closeGorm closes the *sql.DB behind a pooled *gorm.DB. A failed DB()
	// lookup used to leave a nil *sql.DB whose Close panicked.
	closeGorm := func(v interface{}) error {
		gdb, ok := v.(*gorm.DB)
		if !ok || gdb == nil {
			return fmt.Errorf("unexpected gorm pool connection type %T", v)
		}
		sqlDB, err := gdb.DB()
		if err != nil {
			return err
		}
		return sqlDB.Close()
	}

	// pingGorm checks that the *sql.DB behind a pooled *gorm.DB is still
	// reachable, with the same guards as closeGorm.
	pingGorm := func(v interface{}) error {
		gdb, ok := v.(*gorm.DB)
		if !ok || gdb == nil {
			return fmt.Errorf("unexpected gorm pool connection type %T", v)
		}
		sqlDB, err := gdb.DB()
		if err != nil {
			return err
		}
		return sqlDB.Ping()
	}

	// The pool is sized from this entry's initCap/maxCap.
	p, err := pool.NewChannelPool(&pool.Config{
		InitialCap: initCap,
		MaxCap:     maxCap,
		Factory:    connGorm,
		Close:      closeGorm,
		Ping:       pingGorm,
		// Connections idle longer than this are closed, avoiding the
		// problem of idle connections going stale (EOF) and silently failing.
		IdleTimeout: time.Duration(idleTimeout) * time.Second,
	})
	if err != nil {
		logger.Error("register gorm conn [%s] error:%v", name, err)
		return
	}
	my.insertPool(name, p)
}

// insertPool stores p under name in connPool, so several gorm connections
// can coexist. It then logs the registration through internal.ConsoleLog,
// with my.dsn passed through hideTcpDsnPasswordLog.
func (my *GormDB) insertPool(name string, p pool.Pool) {
	my.lock.Lock()
	defer my.lock.Unlock()

	// Create the map lazily, but under the lock. A GormDB not built by
	// NewGorm has a nil map, and the former unlocked check could race with a
	// concurrent insert and drop a pool.
	if my.connPool == nil {
		my.connPool = make(map[string]pool.Pool)
	}
	my.connPool[name] = p

	hideDsn := hideTcpDsnPasswordLog(my.dsn)
	internal.ConsoleLog(fmt.Sprintf("create Gorm pool [%s](%v) success", name, hideDsn))
}

// getDB borrows a connection from the pool registered under name.
//
// It returns the connection and a put func that returns it to the same pool.
// put is never nil, so callers may always `defer put()`. An error is returned
// if no pool is registered under name or the pool cannot supply a connection.
func (my *GormDB) getDB(name string) (conn interface{}, put func(), err error) {
	put = func() {}

	// A single locked lookup: insertPool writes the map under the same lock.
	my.lock.Lock()
	p, ok := my.connPool[name]
	my.lock.Unlock()
	if !ok {
		return nil, put, errors.New("no Gorm connect")
	}

	conn, err = p.Get()
	if err != nil {
		return nil, put, fmt.Errorf("Gorm get connect err:%w", err)
	}

	// Capture p so the connection goes back to the pool it came from without
	// re-reading the map. put has no error result, so a Put failure is dropped.
	put = func() {
		_ = p.Put(conn)
	}

	return conn, put, nil
}

// putDB returns db to the pool registered under name. It fails if no such
// pool is registered, or with whatever error the pool's Put reports.
func (my *GormDB) putDB(name string, db interface{}) (err error) {
	// A single locked lookup, consistent with insertPool's locked write.
	my.lock.Lock()
	p, ok := my.connPool[name]
	my.lock.Unlock()
	if !ok {
		return errors.New("no Gorm connect")
	}
	return p.Put(db)
}

// GetGorm borrows a *gorm.DB from the pool registered under name.
//
// The returned put func hands the connection back to its pool and is never
// nil, so callers can always `defer put()`. An error is returned in these cases:
//   - NewGorm has not been called.
//   - No pool is registered under name.
//   - The pool cannot supply a connection.
//   - The pooled value is not a *gorm.DB.
func GetGorm(name string) (db *gorm.DB, put func(), err error) {
	put = func() {}
	if gormConn == nil {
		return nil, put, errors.New("db connect is nil")
	}

	conn, put, err := gormConn.getDB(name)
	if err != nil {
		return nil, put, err
	}

	// Comma-ok instead of a bare assertion, which would panic on a foreign
	// value. On mismatch the connection is handed back right away, and a
	// no-op put is returned so it cannot be returned twice.
	db, ok := conn.(*gorm.DB)
	if !ok {
		put()
		return nil, func() {}, fmt.Errorf("gorm pool [%s] returned unexpected connection type %T", name, conn)
	}
	return db, put, nil
}
